{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pprint import pprint\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq\n",
    "\n",
    "from oumi.datasets import AlpacaDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Listing Supported Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.registry import REGISTRY, RegistryType\n",
    "\n",
    "\n",
    "def list_datasets():\n",
    "    \"\"\"List all datasets in the registry.\"\"\"\n",
    "    for key, value in REGISTRY._registry.items():\n",
    "        if key.registry_type == RegistryType.DATASET:\n",
    "            print(key.name, \"->\", value.__name__)\n",
    "\n",
    "\n",
    "list_datasets()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Loading Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the alpaca dataset. Since multiple variants can be registered in the HuggingFace hub, by default we use `Dataset.default`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3-mini-4k-instruct\")\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "dataset = AlpacaDataset(tokenizer=tokenizer)\n",
    "\n",
    "print(f\"Using: {dataset.dataset_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Alternatively, you can pass a custom HuggingFace hub identifier. You can find a list of supported datasets in `Dataset.supported_datasets`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "dataset = AlpacaDataset(dataset_name=\"yahma/alpaca-cleaned\", tokenizer=tokenizer)\n",
    "\n",
    "print(f\"Using: {dataset.dataset_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, the dataset is downloaded from the HuggingFace hub, and cached in the `~/.cache/huggingface/datasets` directory.\n",
    "\n",
    "When instantiating the class, the dataset is loaded in memory. This is acceptable with small datasets, but for larger datasets, we can either use `IterableDataset` for streaming batch from disk, or shard per worker rank (so that Memory // N_GPUs)."
   ]
  },
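  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-rank sharding idea can be sketched in plain Python. This is a minimal sketch under stated assumptions, not how Oumi implements it: `shard_for_rank` is a hypothetical helper, and the list of integers stands in for dataset rows. Each rank keeps every `world_size`-th row, so it holds roughly `1 / world_size` of the data.\n",
    "\n",
    "```python\n",
    "from typing import Iterator, Sequence\n",
    "\n",
    "\n",
    "def shard_for_rank(rows: Sequence, rank: int, world_size: int) -> Iterator:\n",
    "    # Hypothetical helper: yield every world_size-th row, starting at `rank`.\n",
    "    for index in range(rank, len(rows), world_size):\n",
    "        yield rows[index]\n",
    "\n",
    "\n",
    "rows = list(range(10))  # stand-in for dataset rows\n",
    "shards = [list(shard_for_rank(rows, rank, world_size=4)) for rank in range(4)]\n",
    "print(shards)\n",
    "```\n",
    "\n",
    "Inside a PyTorch `IterableDataset`, the same striding would typically be applied in `__iter__`, using the process rank and world size."
   ]
  },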
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Iterating Over Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Number of examples: {len(dataset)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given everything is loaded into memory, we can randomly access any row in the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "dataset[42]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can iterate over the dataset to get the examples, either manually or using a DataLoader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Manual iteration\n",
    "[dataset[i] for i in range(len(dataset))];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# With a pytorch data loader\n",
    "loader = DataLoader(\n",
    "    dataset, batch_size=1, num_workers=0, shuffle=False, collate_fn=lambda x: x\n",
    ")\n",
    "list(loader);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also use any library from the pytorch ecosystem, e.g. `torchtext`, `torchdata`, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchdata.stateful_dataloader import StatefulDataLoader\n",
    "\n",
    "loader = StatefulDataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(loader));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Accessing Individual Examples\n",
    "\n",
    "Alpaca is a \"Supervised Finetuning Dataset\". It contains instructions, user inputs, and model outputs. An SFT dataset has the following methods:\n",
    "\n",
    "**Base Map Dataset**\n",
    "```python\n",
    "dataset[0] -> model inputs  # pytorch convention\n",
    "dataset.raw(0) -> raw data  # oumi convention\n",
    "```\n",
    "**Base SFT Dataset**\n",
    "```python\n",
    "dataset.conversation(0) -> conversation # model independent\n",
    "dataset.prompt(0) -> prompt  # dependends on the tokenizer\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw data (pd.Series, dict), as defined by the dataset authors\n",
    "dataset.raw(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to standard Oumi SFT format (oumi.core.types.Conversation)\n",
    "dataset.conversation(0).messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to model prompt (str). This include the model's chat temlate,\n",
    "# EOS tokens for inference/generation, etc.\n",
    "dataset.prompt(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What is fed to the model.forward (dict)\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Accessing the underlying data backend (for debugging only)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We currently use pd.DataFrame as a backend for the dataset. We can trivially use either HuggingFace `datasets`, or an `arrow` Table as a backend for the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install -U -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.data[[\"instruction\", \"output\", \"input\"]].map(len).hist();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To customize preprocessing behavior, we can override the appropriate method:\n",
    "- `BaseMapDataset.__getitem__` for fully custom behavior, multi-modal, multi-task, etc.\n",
    "- `BaseMapDataset.transform` for custom preprocessing of standard datasets (inputs, labels)\n",
    "- `BaseSftDataset.transform` To customize preprocessing, tokenization of a standard SFT dataset\n",
    "- `BaseSftDataset.transform_conversation` To transform raw data row into oumi conversation"
   ]
  },
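  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layering of these override points can be illustrated with a toy class. This is only a sketch: `ToySftDataset` and `UppercaseInstructions` are hypothetical stand-ins, not the Oumi base classes, and a whitespace split plays the role of the tokenizer. It shows how overriding only `transform_conversation` changes what `__getitem__` returns.\n",
    "\n",
    "```python\n",
    "class ToySftDataset:\n",
    "    # Hypothetical stand-in that mirrors the method layering described above.\n",
    "    def __init__(self, rows):\n",
    "        self.rows = rows\n",
    "\n",
    "    def raw(self, idx):\n",
    "        return self.rows[idx]\n",
    "\n",
    "    def transform_conversation(self, row):\n",
    "        # Override point: raw row -> list of chat messages.\n",
    "        return [\n",
    "            {'role': 'user', 'content': row['instruction']},\n",
    "            {'role': 'assistant', 'content': row['output']},\n",
    "        ]\n",
    "\n",
    "    def transform(self, row):\n",
    "        # Override point: raw row -> model inputs (fake whitespace tokenizer).\n",
    "        messages = self.transform_conversation(row)\n",
    "        text = ' '.join(message['content'] for message in messages)\n",
    "        return {'tokens': text.split()}\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.transform(self.raw(idx))\n",
    "\n",
    "\n",
    "class UppercaseInstructions(ToySftDataset):\n",
    "    # Customizes only the row -> conversation step.\n",
    "    def transform_conversation(self, row):\n",
    "        row = {**row, 'instruction': row['instruction'].upper()}\n",
    "        return super().transform_conversation(row)\n",
    "\n",
    "\n",
    "toy = UppercaseInstructions([{'instruction': 'say hi', 'output': 'hi'}])\n",
    "print(toy[0])\n",
    "```"
   ]
  },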
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = AlpacaDataset(tokenizer=tokenizer)\n",
    "\n",
    "print(f\"Using: {dataset.dataset_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "dataset.raw(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset.raw(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset.conversation(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset.prompt(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Comparing with other datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_dataset = datasets.load_dataset(\"tatsu-lab/alpaca\")\n",
    "hf_dataset = hf_dataset[\"train\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Raw Random Access"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset.raw(idx)  # use oumi dataset random access"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "hf_dataset[idx]  # use huggingface dataset random access"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "hf_dataset.data[\"text\"][idx].as_py()  # directly access the arrow table."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### With Tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_dataset.data[\"text\"][0].as_py()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Arrow\n",
    "text = hf_dataset.data[\"text\"][0].as_py()\n",
    "conversation = {\"messages\": [{\"role\": \"user\", \"content\": text}]}\n",
    "dataset.tokenize(conversation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# HF Datasets\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset.tokenize(dataset.transform_conversation(hf_dataset[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Oumi dataset\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "data = dataset.raw(idx).to_dict()\n",
    "dataset.tokenize(dataset.transform_conversation(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Full Oumi pipeline\n",
    "idx = random.randint(0, len(dataset) - 1)\n",
    "dataset[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterate over dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Manual iteration\n",
    "[dataset[i] for i in range(len(dataset))];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# With a pytorch data loader\n",
    "loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)\n",
    "list(loader);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install -U -q line_profiler memory_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext line_profiler\n",
    "%load_ext memory_profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%prun [dataset[i] for i in range(len(dataset))];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%lprun -f dataset.__getitem__ [dataset[i] for i in range(len(dataset))];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%memit [dataset[i] for i in range(len(dataset))]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Benchmark with Model Forward Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "dataset = AlpacaDataset(tokenizer=tokenizer)\n",
    "collator_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, return_tensors=\"pt\")\n",
    "\n",
    "loader = DataLoader(\n",
    "    dataset, batch_size=3, num_workers=0, shuffle=False, collate_fn=collator_fn\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Including pre-processing\n",
    "with torch.no_grad():\n",
    "    batch = next(iter(loader))\n",
    "    model.forward(**batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_batch = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%timeit\n",
    "\n",
    "# Excluding pre-processing\n",
    "with torch.no_grad():\n",
    "    model.forward(**fixed_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Testing Different Tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "phi_tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3-mini-4k-instruct\")\n",
    "gpt_tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "gpt_tokenizer.pad_token = gpt_tokenizer.eos_token\n",
    "gpt_tokenizer.chat_template = phi_tokenizer.chat_template\n",
    "llama_tokenizer = None\n",
    "try:\n",
    "    llama_tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
    "except Exception as e:\n",
    "    print(\"Skipping Llama 3.2, since you don't have access to it.\")\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = AlpacaDataset()\n",
    "\n",
    "pprint(dataset.conversation(0).messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset._tokenizer = gpt_tokenizer\n",
    "pprint(dataset.prompt(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset._tokenizer = phi_tokenizer\n",
    "pprint(dataset.prompt(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if llama_tokenizer is not None:\n",
    "    dataset._tokenizer = llama_tokenizer\n",
    "    pprint(dataset.prompt(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Test Packing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/Phi-3-mini-4k-instruct\")\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "dataset = AlpacaDataset(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.datasets.pretraining_async_text_dataset import (\n",
    "    PretrainingAsyncTextDataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = PretrainingAsyncTextDataset(\n",
    "    tokenizer,\n",
    "    dataset,\n",
    "    dataset_text_field=None,\n",
    "    formatting_func=lambda x: x,\n",
    "    pretokenized=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(dataset))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
